{
 "cells": [
  {
   "cell_type": "raw",
   "source": [
    " # Custom Loss Function\n",
    " The following tutorial shows how to define a custom loss function and use it for training a model."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": ""
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f7d979c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add the directory two levels above the working directory to the module search path\n",
    "# (presumably so a local ampligraph checkout is importable - verify).\n",
    "import sys\n",
    "sys.path.append('../..')\n",
    "import os\n",
    "# The variables below are assigned before `import tensorflow`, so they are\n",
    "# already in place when the library loads.\n",
    "# CUDA device selection: only device index 0 is listed.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "# oneDNN optimisation flag set to '0' (off).\n",
    "os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'\n",
    "# Minimum TensorFlow C++ log level set to '3' (presumably silences\n",
    "# info/warning/error messages from the native runtime - verify).\n",
    "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
    "import tensorflow as tf\n",
    "# Python-side TensorFlow logger: report ERROR level and above only.\n",
    "tf.get_logger().setLevel('ERROR')\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "dafa4c11",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ampligraph\n",
    "# Benchmark dataset loaders live in the ampligraph.datasets module\n",
    "from ampligraph.datasets import load_fb15k_237\n",
    "# Load the FB15k-237 benchmark; the training cell further down reads\n",
    "# the 'train' entry of the returned object.\n",
    "dataset = load_fb15k_237()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49e03ab8",
   "metadata": {},
   "source": [
    "## Train the model using a custom loss function\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "94e04064",
   "metadata": {},
   "outputs": [],
   "source": [
    "def userLoss(scores_pos, scores_neg):\n",
    "    \"\"\"User-defined loss: negative log of a softmax over positive and negative scores.\n",
    "\n",
    "    Mathematically this is\n",
    "    ``-log(exp(scores_pos) / (sum(exp(scores_neg), axis=0) + exp(scores_pos)))``.\n",
    "    It is evaluated as ``softplus(logsumexp(scores_neg, axis=0) - scores_pos)``,\n",
    "    which is the same quantity but never exponentiates a raw score, so large\n",
    "    scores cannot overflow to inf/NaN.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    scores_pos : tf.Tensor\n",
    "        Scores of the positive triples.\n",
    "    scores_neg : tf.Tensor\n",
    "        Scores of the negative triples; reduced along axis 0 (assumed to\n",
    "        index the negatives belonging to each positive - TODO confirm).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tf.Tensor\n",
    "        Loss values, shaped like the broadcast of ``scores_pos`` with the\n",
    "        axis-0 reduction of ``scores_neg``.\n",
    "    \"\"\"\n",
    "    # log(sum(exp(scores_neg), axis=0)) without materialising exp(scores_neg).\n",
    "    neg_logsumexp = tf.reduce_logsumexp(scores_neg, axis=0)\n",
    "    # -log(e^p / (S + e^p)) == log(1 + S * e^-p) == softplus(log(S) - p)\n",
    "    return tf.math.softplus(neg_logsumexp - scores_pos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "87ed1243",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metal device set to: Apple M1 Pro\n",
      "\n",
      "systemMemory: 32.00 GB\n",
      "maxCacheSize: 10.67 GB\n",
      "\n",
      "Epoch 1/10\n",
      "29/29 [==============================] - 2s 85ms/step - loss: 6736.2095\n",
      "Epoch 2/10\n",
      "29/29 [==============================] - 1s 32ms/step - loss: 6734.3237\n",
      "Epoch 3/10\n",
      "29/29 [==============================] - 1s 33ms/step - loss: 6723.1787\n",
      "Epoch 4/10\n",
      "29/29 [==============================] - 1s 34ms/step - loss: 6672.5737\n",
      "Epoch 5/10\n",
      "29/29 [==============================] - 1s 31ms/step - loss: 6530.2490\n",
      "Epoch 6/10\n",
      "29/29 [==============================] - 1s 31ms/step - loss: 6262.3530\n",
      "Epoch 7/10\n",
      "29/29 [==============================] - 1s 31ms/step - loss: 5894.4731\n",
      "Epoch 8/10\n",
      "29/29 [==============================] - 1s 30ms/step - loss: 5490.6753\n",
      "Epoch 9/10\n",
      "29/29 [==============================] - 1s 30ms/step - loss: 5099.8218\n",
      "Epoch 10/10\n",
      "29/29 [==============================] - 1s 29ms/step - loss: 4743.1699\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x298d91840>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Import the KGE model\n",
    "from ampligraph.latent_features import ScoringBasedEmbeddingModel\n",
    "\n",
    "# Create the model with the ComplEx scoring function.\n",
    "# Per the AmpliGraph docs: eta = negatives generated per positive triple,\n",
    "# k = embedding size.\n",
    "model = ScoringBasedEmbeddingModel(eta=1,\n",
    "                                   k=100,\n",
    "                                   scoring_type='ComplEx')\n",
    "\n",
    "# Compile the model with the Adam optimizer and the user-defined loss;\n",
    "# the loss is passed as a plain callable taking (scores_pos, scores_neg).\n",
    "model.compile(optimizer='adam', loss=userLoss)\n",
    "\n",
    "# Fit on the training split of the FB15k-237 dataset loaded earlier in the\n",
    "# notebook (no need to load it a second time here).\n",
    "model.fit(dataset['train'],\n",
    "          batch_size=10000,\n",
    "          epochs=10)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 ('base')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "vscode": {
   "interpreter": {
    "hash": "2e69f3670cdad0193847aaa0b77be56c05c951fcbdd384ff882dde0464f4de76"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
